from Voicelab.pipeline.Node import Node
from parselmouth.praat import call
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

import pandas as pd

from Voicelab.toolkits.Voicelab.VoicelabNode import VoicelabNode
from Voicelab.toolkits.Voicelab.MeasureJitterNode import MeasureJitterNode

###################################################################################################
# MEASURE JITTER PCA NODE
# WARIO pipeline node for performing principal component analysis on the jitter of a voice.
###################################################################################################
# ARGUMENTS
# 'voice'   : sound file generated by parselmouth praat
###################################################################################################
# RETURNS
# 'jitter_pca_df' : nested list holding the single principal component, or None if the PCA fails
###################################################################################################


class MeasureJitterPCANode(VoicelabNode):
    """Pipeline node that reduces five jitter measures of a voice to one PCA component.

    The five measures (local, local absolute, RAP, ppq5 and ddp jitter) are read
    from ``self.args``. By default each entry is a callable that measures the
    voice through a ``MeasureJitterNode``; an entry may also be replaced by an
    already-computed value, in which case it is used as-is.
    """

    def __init__(self, *args, **kwargs):
        """Install the default jitter sources and an empty measurement cache.

        Args:
            *args: Forwarded unchanged to the parent node constructor.
            **kwargs: Forwarded unchanged to the parent node constructor.
        """
        super().__init__(*args, **kwargs)

        # Each default is lazy: nothing is measured until process() asks for it.
        # The string keys on the right are the keys looked up in the dictionary
        # returned by MeasureJitterNode.process().
        self.args = {
            "local_jitter": lambda voice: self.measure_jitter(voice)["Local Jitter"],
            "localabsolute_jitter": lambda voice: self.measure_jitter(voice)[
                "Local Absolute Jitter"
            ],
            "rap_jitter": lambda voice: self.measure_jitter(voice)["RAP Jitter"],
            "ppq5_jitter": lambda voice: self.measure_jitter(voice)["ppq5 Jitter"],
            "ddp_jitter": lambda voice: self.measure_jitter(voice)["ddp Jitter"],
        }

        # Holds the jitter results of the most recently measured voice only.
        self.cached = {}

    ###############################################################################################
    # process: WARIO hook called once for each voice file.
    ###############################################################################################

    def process(self):
        """Run a one-component PCA over the standardized jitter measures.

        Returns:
            dict: ``{"jitter_pca_df": values}`` where ``values`` is the PCA output
            as a nested list, or ``None`` when scaling or the PCA raised.

        Raises:
            KeyError: If ``self.args`` has no ``"voice"`` entry.
        """
        voice = self.args["voice"]

        # Plain scalars in every column (the one-element tuples previously built
        # for three of the measures were accidental trailing commas).
        jitter_values = {
            "local_jitter": self._resolve_arg("local_jitter", voice),
            "local_abs_jitter": self._resolve_arg("localabsolute_jitter", voice),
            "rap_jitter": self._resolve_arg("rap_jitter", voice),
            "ppq5_jitter": self._resolve_arg("ppq5_jitter", voice),
            "ddp_jitter": self._resolve_arg("ddp_jitter", voice),
        }

        # One row (index 0) with one column per jitter measure.
        jitter_df = pd.DataFrame(data=jitter_values, index=[0])

        # NOTE(review): the frame holds a single row, so standardizing it yields
        # all zeros and the component cannot carry information. Presumably the
        # PCA was meant to run across all voices - TODO confirm.
        try:
            scaled = StandardScaler().fit_transform(jitter_df)

            # Run the PCA
            pca = PCA(n_components=1)
            principal_components = pca.fit_transform(scaled)
            jitter_pca_df = pd.DataFrame(
                data=principal_components, columns=["JitterPCA"]
            )

            return {"jitter_pca_df": jitter_pca_df.values.tolist()}

        # Best effort: report and return None, but let KeyboardInterrupt and
        # SystemExit propagate (a bare ``except`` would swallow them too).
        except Exception:
            print(
                "Jitter PCA failed.  Please check your data"
            )  # todo make this an error message
            return {"jitter_pca_df": None}

    def _resolve_arg(self, name, voice):
        """Return the value of ``self.args[name]``, measuring it when it is callable.

        Args:
            name: Key into ``self.args``.
            voice: The voice handed to the entry when it is a callable.

        Returns:
            The result of calling the entry with ``voice``, or the entry itself
            when an upstream node supplied a precomputed value.
        """
        value = self.args[name]
        return value(voice) if callable(value) else value

    # I want some way to run this node without needing any upstream nodes (except loading a voice)
    # but I also want to optionally pass these values in. This may conflict with our desire to keep
    # each node suitably atomic in its behaviour
    def measure_jitter(self, voice):
        """Measure the jitter of ``voice``, reusing the last result when possible.

        Args:
            voice: The voice to measure; passed to a ``MeasureJitterNode`` as its
                ``"voice"`` argument.

        Returns:
            Whatever ``MeasureJitterNode.process()`` returns for this voice.
        """
        try:
            if voice in self.cached:
                return self.cached[voice]
            cacheable = True
        except TypeError:
            # An unhashable voice cannot be a dict key; measure it every time
            # instead of failing before any measurement happens.
            cacheable = False

        # NOTE(review): the label says "Shimmer" although this is a jitter node;
        # left untouched because its effect on the pipeline is not visible here.
        jitter_node = MeasureJitterNode("Measure Shimmer")
        jitter_node.args["voice"] = voice
        results = jitter_node.process()

        if cacheable:
            # The cache is replaced, not extended, so only the latest voice is kept.
            self.cached = {voice: results}
        return results
